Align low-resolution brain surfaces of 2 individuals with fMRI data#

In this example, we align 2 low-resolution left hemispheres using 4 fMRI feature maps (z-score contrast maps).

import pickle

import gdist
import matplotlib as mpl
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np

from fugw.mappings import FUGW
from mpl_toolkits.axes_grid1 import make_axes_locatable
from nilearn import datasets, image, plotting, surface

Let’s download 5 volumetric contrast maps per individual using nilearn’s API. We will use the first 4 of them to compute an alignment between the source and target subjects, and use the left-out contrast to assess the quality of our alignment.

n_subjects = 2

contrasts = [
    "sentence reading vs checkerboard",
    "sentence listening",
    "calculation vs sentences",
    "left vs right button press",
    "checkerboard",
]
n_training_contrasts = 4

brain_data = datasets.fetch_localizer_contrasts(
    contrasts,
    n_subjects=n_subjects,
    get_anats=True,
)

source_imgs_paths = brain_data["cmaps"][0 : len(contrasts)]
target_imgs_paths = brain_data["cmaps"][len(contrasts) : 2 * len(contrasts)]
Dataset created in /github/home/nilearn_data/brainomics_localizer

/usr/local/lib/python3.8/site-packages/nilearn/datasets/func.py:763: UserWarning:

`legacy_format` will default to `False` in release 0.11. Dataset fetchers will then return pandas dataframes by default instead of recarrays.

Here is what the first contrast map of the source subject looks like (the following figure is interactive):

contrast_index = 0
plotting.view_img(
    source_imgs_paths[contrast_index],
    brain_data["anats"][0],
    title=f"Contrast {contrast_index} (source subject)",
    opacity=0.5,
)
/usr/local/lib/python3.8/site-packages/nilearn/_utils/niimg.py:63: UserWarning:

Non-finite values detected. These values will be replaced with zeros.

/usr/local/lib/python3.8/site-packages/numpy/core/fromnumeric.py:784: UserWarning:

Warning: 'partition' will ignore the 'mask' of the MaskedArray.


Computing feature arrays#

Let’s project these maps onto a mesh representing the cortical surface and stack the projections into one feature array per subject. All 5 contrast maps are projected here, although only the first 4 are used for training. To keep the training phase of our mapping short even on CPU, we project the volumetric maps onto a very low-resolution mesh (fsaverage3) made of 642 vertices.

fsaverage3 = datasets.fetch_surf_fsaverage(mesh="fsaverage3")


def load_images_and_project_to_surface(image_paths):
    """Util function for loading and projecting volumetric images."""
    images = [image.load_img(img) for img in image_paths]
    surface_images = [
        np.nan_to_num(surface.vol_to_surf(img, fsaverage3.pial_left))
        for img in images
    ]

    return np.stack(surface_images)


source_features = load_images_and_project_to_surface(source_imgs_paths)
target_features = load_images_and_project_to_surface(target_imgs_paths)
source_features.shape
/usr/local/lib/python3.8/site-packages/nilearn/surface/surface.py:465: RuntimeWarning:

Mean of empty slice


(5, 642)

Here is a figure showing the 5 projected maps (the 4 training contrasts and the left-out one) for each of the 2 individuals:

def plot_surface_map(surface_map, cmap="coolwarm", colorbar=True, **kwargs):
    """Util function for plotting surfaces."""
    plotting.plot_surf(
        fsaverage3.pial_left,
        surface_map,
        cmap=cmap,
        colorbar=colorbar,
        bg_map=fsaverage3.sulc_left,
        bg_on_data=True,
        darkness=0.5,
        **kwargs,
    )


fig = plt.figure(figsize=(3 * n_subjects, 3 * len(contrasts)))
grid_spec = gridspec.GridSpec(len(contrasts), n_subjects, figure=fig)

# Plot all feature maps
for i, contrast_name in enumerate(contrasts):
    for j, features in enumerate([source_features, target_features]):
        ax = fig.add_subplot(grid_spec[i, j], projection="3d")
        plot_surface_map(
            features[i, :], axes=ax, vmax=10, vmin=-10, colorbar=False
        )

    # Add labels to subplots
    if i == 0:
        for j in range(2):
            ax = fig.add_subplot(grid_spec[i, j])
            ax.axis("off")
            ax.text(0.5, 1, f"sub-0{j}", va="center", ha="center")

    ax = fig.add_subplot(grid_spec[i, :])
    ax.axis("off")
    ax.text(0.5, 0, contrast_name, va="center", ha="center")

# Add colorbar
ax = fig.add_subplot(grid_spec[2, :])
ax.axis("off")
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="2%")
fig.add_axes(cax)
fig.colorbar(
    mpl.cm.ScalarMappable(
        norm=mpl.colors.Normalize(vmin=-10, vmax=10), cmap="coolwarm"
    ),
    cax=cax,
)

plt.show()
Figure: projected contrast maps (rows) for sub-00 and sub-01 (columns)

Computing geometry arrays#

Now we compute the kernel matrix of distances between vertices on the cortical surface. Note that in this example, we are using the same mesh for the source and target individuals, but this does not have to be the case in general.

def compute_geometry_from_mesh(mesh_path):
    """Util function to compute matrix of geodesic distances of a mesh."""
    (coordinates, triangles) = surface.load_surf_mesh(mesh_path)
    geometry = gdist.local_gdist_matrix(
        coordinates.astype(np.float64), triangles.astype(np.int32)
    ).toarray()

    return geometry


fsaverage3_pial_left_geometry = compute_geometry_from_mesh(
    fsaverage3.pial_left
)
source_geometry = fsaverage3_pial_left_geometry
target_geometry = fsaverage3_pial_left_geometry
source_geometry.shape
(642, 642)

Each row vertex_index of the geometry matrices contains the anatomical distance (here in millimeters) from vertex_index to all other vertices of the mesh.

vertex_index = 4

fig = plt.figure(figsize=(5, 5))
ax = fig.add_subplot(111, projection="3d")
ax.set_title("Geodesic distance in mm\non the cortical surface")
plot_surface_map(
    source_geometry[vertex_index, :],
    cmap="magma",
    cbar_tick_format="%.2f",
    axes=ax,
)
plt.show()
Geodesic distance in mm on the cortical surface
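To make the structure of such a geometry matrix concrete, here is a minimal sketch on a toy mesh made of two triangles. It is not what gdist does: it only measures shortest paths along mesh edges (a Floyd-Warshall pass written with NumPy), which is a cruder approximation that can overestimate distances across triangle faces. The function name and the toy mesh are illustrative.

```python
import numpy as np


def edge_path_distances(coordinates, triangles):
    """All-pairs shortest-path lengths along mesh edges (Floyd-Warshall)."""
    n_vertices = coordinates.shape[0]
    dist = np.full((n_vertices, n_vertices), np.inf)
    np.fill_diagonal(dist, 0.0)

    # Initialise with the Euclidean length of every triangle edge
    for triangle in triangles:
        for i, j in ((0, 1), (1, 2), (2, 0)):
            a, b = triangle[i], triangle[j]
            length = np.linalg.norm(coordinates[a] - coordinates[b])
            dist[a, b] = dist[b, a] = length

    # Relax paths going through each intermediate vertex k
    for k in range(n_vertices):
        dist = np.minimum(dist, dist[:, [k]] + dist[[k], :])

    return dist


# Toy mesh: a unit square split into two triangles along the 0-2 diagonal
coordinates = np.array(
    [[0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0]], dtype=float
)
triangles = np.array([[0, 1, 2], [0, 2, 3]])

toy_geometry = edge_path_distances(coordinates, triangles)
```

Row i of toy_geometry holds the distances from vertex i to every other vertex, which is the same layout as source_geometry. Vertices 1 and 3 are 2 units apart along edges, even though the straight path across the square is shorter, which illustrates why a proper geodesic solver is preferable on real meshes.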

Normalizing features and geometries#

Features and geometries should be normalized before we can train a mapping. Indeed, without this scaling, it’s unclear whether the source and target features would be comparable. Moreover, the hyper-parameter alpha would depend on the scale of the respective matrices. Finally, it can empirically lead to having nan values in the computed transport plan.
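The arrays source_features_normalized, target_features_normalized, source_geometry_normalized and target_geometry_normalized used in the next section are not defined in the code shown on this page. Below is a minimal sketch of one plausible normalization, assuming each feature map is scaled to unit L2 norm and each geometry matrix is divided by its maximum value. The helper names and the synthetic arrays are illustrative and are not taken from the original script.

```python
import numpy as np


def normalize_features(features):
    """Scale each feature map (row) to unit L2 norm (assumed convention)."""
    norms = np.linalg.norm(features, axis=1, keepdims=True)
    return features / norms


def normalize_geometry(geometry):
    """Divide a distance matrix by its largest entry (assumed convention)."""
    return geometry / np.max(geometry)


# Synthetic stand-ins with the same shapes as the arrays of this example
rng = np.random.default_rng(0)
demo_features = rng.normal(size=(5, 642))
demo_geometry = np.abs(rng.normal(size=(642, 642)))

demo_features_normalized = normalize_features(demo_features)
demo_geometry_normalized = normalize_geometry(demo_geometry)
```

Under these assumptions, the arrays of this example would be obtained with calls such as normalize_features(source_features) and normalize_geometry(source_geometry), and likewise for the target subject.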

Training the mapping#

Let’s create our mapping. We set alpha=0.5 to indicate that we are as interested in matching vertices with similar features as we are in preserving the anatomical geometries of the source and target subjects. We leave rho at its default value, and finally set a value of eps which is low enough for the computed transport plan not to be too regularized. High values of eps lead to faster computations and more regularized (i.e. blurry) plans. Low values of eps lead to slower computations, but finer-grained plans. Note that this package is meant to be used with GPUs; fitting mappings on CPUs is about 100x slower.

alpha = 0.5
rho = 1
eps = 1e-4
mapping = FUGW(alpha=alpha, rho=rho, eps=eps)
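The effect of eps on the blurriness of a plan can be illustrated on a much simpler problem. The sketch below is not the FUGW solver: it runs plain balanced Sinkhorn iterations between two identical sets of 1D points with uniform weights, and compares the entropy of the plans obtained with a high and a low value of eps. All names and values are illustrative.

```python
import numpy as np


def sinkhorn_plan(cost, eps, n_iters=500):
    """Balanced entropic OT between uniform marginals (Sinkhorn iterations)."""
    n, m = cost.shape
    a = np.full(n, 1.0 / n)
    b = np.full(m, 1.0 / m)
    kernel = np.exp(-cost / eps)
    u = np.ones(n)
    v = np.ones(m)
    for _ in range(n_iters):
        u = a / (kernel @ v)
        v = b / (kernel.T @ u)
    return u[:, None] * kernel * v[None, :]


def plan_entropy(plan):
    """Shannon entropy of a plan; larger values mean a more diffuse plan."""
    p = plan[plan > 0]
    return float(-(p * np.log(p)).sum())


# Squared distances between 20 evenly spaced points in [0, 1]
points = np.linspace(0.0, 1.0, 20)
cost = (points[:, None] - points[None, :]) ** 2

blurry_plan = sinkhorn_plan(cost, eps=1.0)
sharp_plan = sinkhorn_plan(cost, eps=0.05)

print(f"entropy with eps=1.0:  {plan_entropy(blurry_plan):.3f}")
print(f"entropy with eps=0.05: {plan_entropy(sharp_plan):.3f}")
```

With the higher eps, mass is spread over many pairs of points; with the lower eps, it concentrates around the diagonal, where each point is matched with itself.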

Let’s fit our mapping! 🚀

Remember to use the training maps only. Moreover, we limit the number of block-coordinate-descent iterations to 3 in order to limit computation time for this example.

_ = mapping.fit(
    source_features_normalized[:n_training_contrasts],
    target_features_normalized[:n_training_contrasts],
    source_geometry=source_geometry_normalized,
    target_geometry=target_geometry_normalized,
    solver="sinkhorn",
    solver_params={
        "nits_bcd": 3,
    },
    verbose=True,
)
[14:36:54] BCD step 1/3    FUGW loss:      0.02978161722421646      dense.py:464


[14:37:25] BCD step 2/3    FUGW loss:      0.00562211824581027      dense.py:464


[14:37:56] BCD step 3/3    FUGW loss:      0.005594733636826277     dense.py:464

Here is the evolution of the FUGW loss during training, broken down into its weighted terms, including the regularization term:

fig, ax = plt.subplots(figsize=(10, 4))
ax.set_title(
    "Sinkhorn mapping training loss\n"
    f"Total training time = {mapping.loss_times[-1]:.1f}s"
)
ax.set_ylabel("Loss")
ax.set_xlabel("BCD step")
ax.stackplot(
    mapping.loss_steps,
    [
        (1 - alpha) * np.array(mapping.loss["wasserstein"]),
        alpha * np.array(mapping.loss["gromov_wasserstein"]),
        rho * np.array(mapping.loss["marginal_constraint_dim1"]),
        rho * np.array(mapping.loss["marginal_constraint_dim2"]),
        eps * np.array(mapping.loss["regularization"]),
    ],
    labels=[
        "wasserstein",
        "gromov_wasserstein",
        "marginal_constraint_dim1",
        "marginal_constraint_dim2",
        "regularization",
    ],
    alpha=0.8,
)
ax.legend()
plt.show()
Sinkhorn mapping training loss Total training time = 80.9s
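Reading off the weights passed to ax.stackplot in the code above, the stacked areas correspond to the decomposition below. The symbols are shorthand introduced here for the entries of mapping.loss, and the idea that these weighted terms add up to the total loss is an assumption rather than something stated on this page.

```latex
\mathcal{L}_{\text{stacked}}
  = (1 - \alpha)\,\mathcal{L}_{\text{wasserstein}}
  + \alpha\,\mathcal{L}_{\text{gromov\_wasserstein}}
  + \rho\,\bigl(\mathcal{L}_{\text{marginal\_dim1}} + \mathcal{L}_{\text{marginal\_dim2}}\bigr)
  + \varepsilon\,\mathcal{L}_{\text{regularization}}
```

With alpha = 0.5 as in this example, the feature-matching and geometry-preserving terms carry the same weight.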

Note that we used the sinkhorn solver here because it’s well known in the optimal transport community, but this library comes with other solvers which are, in most cases, much faster. Let’s retrain our mapping using the mm solver, which implements a majorize-minimization approach to approximate a solution and is used by default in fugw.mappings:

mm_mapping = FUGW(alpha=alpha, rho=rho, eps=eps)

_ = mm_mapping.fit(
    source_features_normalized[:n_training_contrasts],
    target_features_normalized[:n_training_contrasts],
    source_geometry=source_geometry_normalized,
    target_geometry=target_geometry_normalized,
    solver="mm",
    solver_params={
        "nits_bcd": 5,
        "tol_bcd": 1e-10,
        "tol_uot": 1e-10,
    },
    verbose=True,
)
[14:38:02] BCD step 1/5    FUGW loss:      0.036940667778253555     dense.py:464


[14:38:07] BCD step 2/5    FUGW loss:      0.015034018084406853     dense.py:464


[14:38:13] BCD step 3/5    FUGW loss:      0.007663070689886808     dense.py:464


[14:38:19] BCD step 4/5    FUGW loss:      0.006632883567363024     dense.py:464


[14:38:31] BCD step 5/5    FUGW loss:      0.0062191360630095005    dense.py:464

And now with the ibpp solver:

ibpp_mapping = FUGW(alpha=alpha, rho=rho, eps=eps)

_ = ibpp_mapping.fit(
    source_features_normalized[:n_training_contrasts],
    target_features_normalized[:n_training_contrasts],
    source_geometry=source_geometry_normalized,
    target_geometry=target_geometry_normalized,
    solver="ibpp",
    solver_params={
        "nits_bcd": 5,
        "tol_bcd": 1e-10,
        "tol_uot": 1e-10,
    },
    verbose=True,
)
[14:38:37] BCD step 1/5    FUGW loss:      0.034542541950941086     dense.py:464


[14:38:42] BCD step 2/5    FUGW loss:      0.00773228844627738      dense.py:464


[14:38:52] BCD step 3/5    FUGW loss:      0.006187522318214178     dense.py:464


[14:39:10] BCD step 4/5    FUGW loss:      0.0058647324331104755    dense.py:464


[14:39:29] BCD step 5/5    FUGW loss:      0.0057375566102564335    dense.py:464

Computed mappings can easily be saved on disk and loaded back:

# Save mappings
with open("./mapping.pkl", "wb") as f:
    pickle.dump(mapping, f)

# Load mappings
with open("./mapping.pkl", "rb") as f:
    mapping = pickle.load(f)

Here is the evolution of the total FUGW loss during training for the three solvers. Note how, in this case, even though mm and ibpp needed more block-coordinate-descent steps to converge, they were about 2 to 3 times faster than sinkhorn at reaching a similar final FUGW training loss. You might want to tweak solver parameters like nits_bcd and nits_uot to get the fastest convergence rates.

fig = plt.figure(figsize=(4 * 2, 4))
fig.suptitle("Training loss comparison\nSinkhorn vs MM vs IBPP")

ax = fig.add_subplot(121)
ax.set_ylabel("Loss")
ax.set_xlabel("BCD step")
ax.plot(mapping.loss_steps, mapping.loss["total"], label="Sinkhorn FUGW loss")
ax.plot(mm_mapping.loss_steps, mm_mapping.loss["total"], label="MM FUGW loss")
ax.plot(
    ibpp_mapping.loss_steps, ibpp_mapping.loss["total"], label="IBPP FUGW loss"
)
ax.legend()

ax = fig.add_subplot(122)
ax.set_ylabel("Loss")
ax.set_xlabel("Time (in seconds)")
ax.plot(mapping.loss_times, mapping.loss["total"], label="Sinkhorn FUGW loss")
ax.plot(mm_mapping.loss_times, mm_mapping.loss["total"], label="MM FUGW loss")
ax.plot(
    ibpp_mapping.loss_times, ibpp_mapping.loss["total"], label="IBPP FUGW loss"
)
ax.legend()

fig.tight_layout()
plt.show()
Training loss comparison Sinkhorn vs MM vs IBPP
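One way to put a number on such a comparison is to measure how long each solver takes to bring the loss under a common threshold. The helper below is a small sketch of that idea. The function name, the solver labels and the values in runs are hypothetical and are not the measurements of this example.

```python
import numpy as np


def time_to_reach(loss_times, losses, threshold):
    """Return the first recorded time at which loss <= threshold, else None."""
    loss_times = np.asarray(loss_times, dtype=float)
    losses = np.asarray(losses, dtype=float)
    reached = np.flatnonzero(losses <= threshold)
    return float(loss_times[reached[0]]) if reached.size else None


# Hypothetical (time in seconds, loss) curves for two solvers
runs = {
    "solver_a": ([0.0, 30.0, 60.0], [0.030, 0.0060, 0.0056]),
    "solver_b": ([0.0, 6.0, 12.0, 18.0], [0.037, 0.015, 0.0077, 0.0060]),
}

threshold = 0.0061
times = {
    name: time_to_reach(loss_times, losses, threshold)
    for name, (loss_times, losses) in runs.items()
}
print(times)
```

For the mappings fitted above, the same helper could be called with mapping.loss_times and mapping.loss["total"], and likewise for mm_mapping and ibpp_mapping.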

Using the computed mapping#

The computed transport plan is stored in mapping.pi as a torch.Tensor. In this example, the transport plan is small enough that we can display it in its entirety.

pi = mapping.pi.numpy()
fig, ax = plt.subplots(figsize=(10, 10))
ax.set_title("Transport plan", fontsize=20)
ax.set_xlabel("target vertices", fontsize=15)
ax.set_ylabel("source vertices", fontsize=15)
im = plt.imshow(pi, cmap="viridis")
plt.colorbar(im, ax=ax, shrink=0.8)
plt.show()
Transport plan

Each row vertex_index of the computed transport plan can be interpreted, up to normalization, as a probability map describing which vertices of the target should be matched with the source vertex vertex_index.

# Divide by the row sum so that the map sums to 1
probability_map = pi[vertex_index, :] / pi[vertex_index, :].sum()

fig = plt.figure(figsize=(5, 5))
ax = fig.add_subplot(111, projection="3d")
ax.set_title(
    "Probability map of target vertices\n"
    f"being matched with source vertex {vertex_index}"
)
plot_surface_map(probability_map, cmap="viridis", axes=ax)
plt.show()
Probability map of target vertices being matched with source vertex 4
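As a small illustration of this reading of the plan, the sketch below turns every row of a toy plan into a probability distribution by dividing it by its sum, then picks the most likely target vertex for each source vertex. The plan here is random and purely illustrative.

```python
import numpy as np

# Hypothetical plan between 4 source vertices and 6 target vertices
rng = np.random.default_rng(0)
toy_plan = rng.random((4, 6))

# Each row divided by its own mass sums to 1
row_sums = toy_plan.sum(axis=1, keepdims=True)
row_probabilities = toy_plan / row_sums

# Index of the target vertex receiving the most mass from each source vertex
most_likely_target = row_probabilities.argmax(axis=1)
```

The same two operations can be applied to pi to obtain, for every source vertex of this example, the target vertex it is most strongly matched with.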

Using mapping.transform(), we can use the computed mapping to transport any collection of feature maps from the source anatomy onto the target anatomy. Note that, conversely, mapping.inverse_transform() takes feature maps from the target anatomy and transports them on the source anatomy.

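contrast_index = 0
predicted_target_features = mapping.transform(
    source_features[contrast_index, :]
)
predicted_target_features.shape
(642,)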
fig = plt.figure(figsize=(3 * 3, 3))
fig.suptitle("Transporting feature maps of the training set")
grid_spec = gridspec.GridSpec(1, 3, figure=fig)

ax = fig.add_subplot(grid_spec[0, 0], projection="3d")
ax.set_title("Actual source features")
plot_surface_map(
    source_features[contrast_index, :], axes=ax, vmax=10, vmin=-10
)

ax = fig.add_subplot(grid_spec[0, 1], projection="3d")
ax.set_title("Predicted target features")
plot_surface_map(predicted_target_features, axes=ax, vmax=10, vmin=-10)

ax = fig.add_subplot(grid_spec[0, 2], projection="3d")
ax.set_title("Actual target features")
plot_surface_map(
    target_features[contrast_index, :], axes=ax, vmax=10, vmin=-10
)

plt.show()
Transporting feature maps of the training set, Actual source features, Predicted target features, Actual target features
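To give an intuition of what transporting a feature map with a plan can mean, here is a minimal sketch of a barycentric projection: each target vertex receives the average of the source values, weighted by the mass it receives from every source vertex. That mapping.transform() behaves this way is an assumption made for illustration; the library's actual implementation may differ. The function name and toy values are hypothetical.

```python
import numpy as np


def barycentric_transport(plan, source_map):
    """Weighted average of source values for every target vertex."""
    column_mass = plan.sum(axis=0)
    return (plan.T @ source_map) / column_mass


# Toy plan between 2 source vertices (rows) and 3 target vertices (columns)
toy_plan = np.array(
    [
        [0.3, 0.1, 0.0],
        [0.0, 0.1, 0.5],
    ]
)
toy_source_map = np.array([2.0, -4.0])

toy_prediction = barycentric_transport(toy_plan, toy_source_map)
```

The first target vertex only receives mass from the first source vertex and inherits its value, the last one only from the second source vertex, and the middle one receives equal mass from both and ends up with their average.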

Here, we transported a feature map which is part of the training set, which does not really help evaluate the quality of our model. Instead, we can also use the computed mapping to transport unseen data, which is how we will usually assess whether our model has captured useful information or not:

contrast_index = len(contrasts) - 1
predicted_target_features = mapping.transform(
    source_features[contrast_index, :]
)

fig = plt.figure(figsize=(3 * 3, 3))
fig.suptitle("Transporting feature maps of the test set")
grid_spec = gridspec.GridSpec(1, 3, figure=fig)

ax = fig.add_subplot(grid_spec[0, 0], projection="3d")
ax.set_title("Actual source features")
plot_surface_map(
    source_features[contrast_index, :], axes=ax, vmax=10, vmin=-10
)

ax = fig.add_subplot(grid_spec[0, 1], projection="3d")
ax.set_title("Predicted target features")
plot_surface_map(predicted_target_features, axes=ax, vmax=10, vmin=-10)

ax = fig.add_subplot(grid_spec[0, 2], projection="3d")
ax.set_title("Actual target features")
plot_surface_map(
    target_features[contrast_index, :], axes=ax, vmax=10, vmin=-10
)

plt.show()
Transporting feature maps of the test set, Actual source features, Predicted target features, Actual target features
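Beyond visual inspection, one simple way to quantify such an assessment is to compare the correlation between the predicted and actual target maps with the correlation between the untransported source map and the target map. The sketch below shows this comparison on synthetic arrays. The metric, the helper name and the data are illustrative choices and are not part of the original example.

```python
import numpy as np


def pearson_correlation(x, y):
    """Pearson correlation coefficient between two 1D maps."""
    return float(np.corrcoef(x, y)[0, 1])


# Synthetic stand-ins: the "predicted" map is closer to the target
# than the raw source map is
rng = np.random.default_rng(0)
toy_target = rng.normal(size=642)
toy_source = toy_target + rng.normal(scale=2.0, size=642)
toy_predicted = toy_target + rng.normal(scale=0.5, size=642)

baseline_score = pearson_correlation(toy_source, toy_target)
aligned_score = pearson_correlation(toy_predicted, toy_target)

print(f"source vs target:    {baseline_score:.2f}")
print(f"predicted vs target: {aligned_score:.2f}")
```

On the data of this example, the same helper could be applied to predicted_target_features and target_features[contrast_index, :], after converting them to NumPy arrays if needed. An alignment that generalizes should score higher than the baseline computed from source_features[contrast_index, :].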

Total running time of the script: ( 3 minutes 46.208 seconds)

Estimated memory usage: 205 MB
